/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "built_model_ndk_impl.h"

#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_align.h"
#include "model_manager/general_model_manager/ndk/ndk_util/ndk_util.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_nncore.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_helper.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_options.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_builtmodel.h"
// inc
#include "infra/base/securestl.h"
#include "framework/common/types.h"
// src/framework/inc
#include "infra/base/assertion.h"
#include "base/common/file_util/file_util.h"
#include "framework/infra/log/log.h"
// src/framework
#include "model/built_model/customdata_util.h"
#include "model/aipp/aipp_input_converter.h"

#include <functional>
#include "securec.h"

namespace hiai {

// Both shared_ptr parameters are taken by value; move them into the members to
// avoid a redundant refcount increment/decrement (modelBuffer was copied before).
BuiltModelNDKImpl::BuiltModelNDKImpl(std::shared_ptr<OH_NNCompilation> builtModel,
    std::shared_ptr<IBuffer> modelBuffer)
    : nnCompilation_(std::move(builtModel)), modelBuffer_(std::move(modelBuffer))
{
}

BuiltModelNDKImpl::~BuiltModelNDKImpl()
{
    // Drop the executor explicitly so it is released before the compilation member.
    nnExecutor_ = nullptr;
}

// Returns the underlying compilation with shared ownership; the caller may keep
// it alive beyond the lifetime of this object.
std::shared_ptr<OH_NNCompilation> BuiltModelNDKImpl::GetNNCompilation()
{
    return nnCompilation_;
}

namespace {
std::string GetModelNamefromModel(const std::shared_ptr<IBuffer> buffer)
{
    if (buffer->GetSize() < sizeof(ModelFileHeader) || buffer->GetData() == nullptr) {
        FMK_LOGW("size of buffer < size of modelheader");
        return "";
    }

    std::string namePrefix = "default_";
    ModelFileHeader header = *reinterpret_cast<const ModelFileHeader*>(buffer->GetData());
    if (header.magic != MODEL_FILE_MAGIC_NUM) {
        return "";
    }
    int32_t nameLen = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(header.name)));
    std::string modelName = namePrefix +
            std::string(reinterpret_cast<const char*>(header.name), std::min(nameLen, 32));
    return modelName;
}

Status NNTensorDescToNDTensorDesc(const NN_TensorDesc* nntensorDesc, NDTensorDesc &ndTensorDesc)
{
    int32_t* shape = nullptr;
    size_t shapeLength = 0;
    Status retCode = HIAI_NDK_NNTensorDesc_GetShape(nntensorDesc, &shape, &shapeLength);
    HIAI_EXPECT_TRUE(retCode == SUCCESS);

    OH_NN_DataType dataType;
    retCode = HIAI_NDK_NNTensorDesc_GetDataType(nntensorDesc, &dataType);
    HIAI_EXPECT_TRUE(retCode == SUCCESS);

    OH_NN_Format format;
    retCode = HIAI_NDK_NNTensorDesc_GetFormat(nntensorDesc, &format);
    HIAI_EXPECT_TRUE(retCode == SUCCESS);

    const char* tensorName = nullptr;
    retCode = HIAI_NDK_NNTensorDesc_GetName(nntensorDesc, &tensorName);
    HIAI_EXPECT_TRUE(retCode == SUCCESS);

    for (size_t i = 0; i < shapeLength; i++) {
        ndTensorDesc.dims.push_back(shape[i]);
    }
    ndTensorDesc.dataType = static_cast<DataType>(HIAIAlign::ConvertNNDataTypeToHIAI(dataType));
    ndTensorDesc.format = static_cast<Format>(HIAIAlign::ConvertNNFormatToHIAI(format));
    ndTensorDesc.tensorName = tensorName == nullptr ? "" : std::string(tensorName);
    return SUCCESS;
}
}

// The caller owns `buffer` and is responsible for releasing its memory.
Status BuiltModelNDKImpl::SaveToExternalBuffer(std::shared_ptr<IBuffer>& buffer, size_t& realSize) const
{
#ifdef AI_SUPPORT_BUILT_MODEL_SAVE
    HIAI_EXPECT_NOT_NULL(nnCompilation_);
    constexpr size_t MAX_EXTERNAL_BUFFER_SIZE = 2147483648UL; // Max size of 2GB.
    HIAI_EXPECT_NOT_NULL(buffer);
    HIAI_EXPECT_NOT_NULL(buffer->GetData());
    HIAI_EXPECT_TRUE(buffer->GetSize() < MAX_EXTERNAL_BUFFER_SIZE);

    // Custom data (if present) is laid out first; the model cache follows it.
    size_t offset = 0;
    if (!customModelData_.type.empty()) {
        HIAI_EXPECT_EXEC(CustomDataUtil::CopyCustomDataToBuffer(buffer, offset, customModelData_));
    }

    char* cacheDst = reinterpret_cast<char*>(buffer->GetData()) + offset;
    Status ret = HIAI_NDK_NNCompilation_ExportCacheToBuffer(nnCompilation_.get(),
        static_cast<void*>(cacheDst), buffer->GetSize() - offset, &realSize);
    if (ret != SUCCESS) {
        return FAILURE;
    }
    realSize += offset;
    return SUCCESS;
#else
    (void)buffer;
    (void)realSize;
    FMK_LOGE("Not supported.");
    return FAILURE;
#endif
}

// The allocated memory does not need to be released by the caller; ownership is
// transferred to the returned IBuffer.
Status BuiltModelNDKImpl::SaveToBuffer(std::shared_ptr<IBuffer>& buffer) const
{
#ifdef AI_SUPPORT_BUILT_MODEL_SAVE
    HIAI_EXPECT_NOT_NULL(nnCompilation_);
    constexpr size_t TEMP_BUFFER_SIZE = 200 * 1024 * 1024; // 200MB scratch buffer.
    std::unique_ptr<uint8_t[]> tempData(new (std::nothrow) uint8_t[TEMP_BUFFER_SIZE]());
    if (tempData == nullptr) {
        return FAILURE;
    }
    size_t realSize = 0;

    // ExportCacheToBuffer memsets the destination internally, so the scratch
    // buffer must be allocated up front even though the real size is only
    // known after the export.
    Status status = HIAI_NDK_NNCompilation_ExportCacheToBuffer(
        nnCompilation_.get(), tempData.get(), TEMP_BUFFER_SIZE, &realSize);
    if (status != SUCCESS) {
        return FAILURE;
    }

    // Shrink-copy into a right-sized buffer.
    auto* realData = new (std::nothrow) uint8_t[realSize];
    // (fix) the allocation result was previously passed to memcpy_s unchecked;
    // a failed allocation would have dereferenced nullptr.
    if (realData == nullptr) {
        FMK_LOGE("alloc real buffer failed.");
        return FAILURE;
    }
    if (memcpy_s(realData, realSize, tempData.get(), realSize) != EOK) {
        delete[] realData;
        return FAILURE;
    }

    // On success the returned buffer takes ownership of realData.
    buffer = CustomDataUtil::SaveCustomDataToBuffer(realData, realSize, customModelData_);
    if (buffer == nullptr) {
        delete[] realData;
        return FAILURE;
    }

    return SUCCESS;
#else
    (void)buffer;
    FMK_LOGE("Not supported.");
    return FAILURE;
#endif
}

// Serializes the built model into an in-memory buffer, then writes it to `file`.
Status BuiltModelNDKImpl::SaveToFile(const char* file) const
{
#ifdef AI_SUPPORT_BUILT_MODEL_SAVE
    HIAI_EXPECT_NOT_NULL(nnCompilation_);
    HIAI_EXPECT_NOT_NULL(file);
    HIAI_EXPECT_EXEC(FileUtil::CreateEmptyFile(file));

    std::shared_ptr<IBuffer> buffer;
    // (fix) removed the stray ';' that was inside the macro argument list
    // (HIAI_EXPECT_EXEC(SaveToBuffer(buffer););), which only compiled by
    // accident of the macro's expansion.
    HIAI_EXPECT_EXEC(SaveToBuffer(buffer));

    return FileUtil::WriteBufferToFile(buffer->GetData(), buffer->GetSize(), file);
#else
    (void)file; // (fix) silence unused-parameter warning in the stub branch
    FMK_LOGE("Not supported.");
    return FAILURE;
#endif
}

// Restores a built model from an in-memory blob: strips the custom-data section
// (if any), constructs and builds an NNCompilation from the remaining offline
// model bytes, and derives the model name from the model header.
Status BuiltModelNDKImpl::RestoreFromBuffer(const std::shared_ptr<IBuffer>& buffer)
{
    HIAI_EXPECT_NOT_NULL(buffer);
    HIAI_EXPECT_NOT_NULL(buffer->GetData());
    HIAI_EXPECT_TRUE(buffer->GetSize() > 0);
    HIAI_EXPECT_TRUE(nnCompilation_ == nullptr); // restore must happen only once

    // Split off the custom (e.g. AIPP) data; the remainder is the raw model.
    const std::shared_ptr<IBuffer> modelData = CustomDataUtil::GetModelData(buffer, customModelData_);
    HIAI_EXPECT_NOT_NULL(modelData);
    modelBuffer_ = modelData;

    OH_NNCompilation* compilation =
        HIAI_NDK_NNCompilation_ConstructWithOfflineModelBuffer(modelData->GetData(), modelData->GetSize());
    HIAI_EXPECT_NOT_NULL(compilation);

    if (HIAI_NDK_NNCompilation_Build(compilation) != SUCCESS) {
        HIAI_NDK_NNCompilation_Destroy(&compilation);
        return FAILURE;
    }

    nnCompilation_.reset(compilation, [](OH_NNCompilation* p) { HIAI_NDK_NNCompilation_Destroy(&p); });
    HIAI_EXPECT_NOT_NULL(nnCompilation_);

    modelName_ = GetModelNamefromModel(modelBuffer_);
    return SUCCESS;
}

// Restores a built model from a file. Files carrying custom data are loaded
// into memory and routed through RestoreFromBuffer; plain model-cache files are
// attached to the compilation directly via SetCache.
Status BuiltModelNDKImpl::RestoreFromFile(const char* file)
{
    HIAI_EXPECT_TRUE(file != nullptr && std::string(file) != "");
    HIAI_EXPECT_TRUE(nnCompilation_ == nullptr); // restore must happen only once

    if (CustomDataUtil::HasCustomData(file)) {
        buffer_ = FileUtil::LoadToBuffer(file);
        HIAI_EXPECT_NOT_NULL(buffer_);

        std::shared_ptr<hiai::IBuffer> buffer =
            CreateLocalBuffer(static_cast<void*>(buffer_->MutableData()), buffer_->GetSize(), false);

        return RestoreFromBuffer(buffer);
    }

    modelPath_ = std::string(file);
    OH_NNCompilation* compilation = HIAI_NDK_NNCompilation_ConstructForCache();
    HIAI_EXPECT_NOT_NULL(compilation);
    // (fix) the SetCache return value was previously ignored; a failure would
    // only surface later as a less specific build error.
    if (HIAI_NDK_NNCompilation_SetCache(compilation, modelPath_.c_str(), 0) != SUCCESS) {
        HIAI_NDK_NNCompilation_Destroy(&compilation);
        FMK_LOGE("HIAI_NDK_NNCompilation_SetCache failed");
        return FAILURE;
    }
    if (HIAI_NDK_NNCompilation_Build(compilation) != SUCCESS) {
        HIAI_NDK_NNCompilation_Destroy(&compilation);
        FMK_LOGE("HIAI_NDK_NNCompilation_Build failed");
        return FAILURE;
    }

    nnCompilation_.reset(compilation, [](OH_NNCompilation* p) { HIAI_NDK_NNCompilation_Destroy(&p); });
    HIAI_EXPECT_NOT_NULL(nnCompilation_);
    return SUCCESS;
}

// Reports whether the restored model is compatible with the current device.
// A file path takes precedence over an in-memory buffer (same precedence the
// previous implementation had, where the file result overwrote the buffer one).
Status BuiltModelNDKImpl::CheckCompatibility(bool& compatible) const
{
    // A model must have been restored from either a buffer or a file.
    HIAI_EXPECT_TRUE(modelBuffer_ != nullptr || !modelPath_.empty());

    HiAI_Compatibility compatibility = HIAI_COMPATIBILITY_COMPATIBLE;
    if (!modelPath_.empty()) {
        compatibility = HIAI_NDK_HiAICompatibility_CheckFromFile(modelPath_.c_str());
    } else {
        // (fix) previously the buffer check also ran when a path was present,
        // and its result was discarded — wasted work for large models.
        compatibility = HIAI_NDK_HiAICompatibility_CheckFromBuffer(
            reinterpret_cast<void*>(reinterpret_cast<char*>(modelBuffer_->GetData())), modelBuffer_->GetSize());
    }

    compatible = (compatibility == HIAI_COMPATIBILITY_COMPATIBLE);
    return SUCCESS;
}

Status BuiltModelNDKImpl::CheckUpdatability(bool& updatable) const
{
    (void)updatable;
    // Updatability queries are not available through the NDK backend.
    FMK_LOGE("Not supported.");
    return FAILURE;
}

Status BuiltModelNDKImpl::GetLibraryTimestamp(std::string& currentModelLibraryTimestamp,
    std::string& availableModelLibraryTimestamp) const
{
    (void)currentModelLibraryTimestamp;
    (void)availableModelLibraryTimestamp;
    // Library timestamp queries are not available through the NDK backend.
    FMK_LOGE("Not supported.");
    return FAILURE;
}

std::vector<NDTensorDesc> BuiltModelNDKImpl::GetInputTensorDescs() const
{
    std::vector<NDTensorDesc> ndTensorDescs;
    if (nnCompilation_ == nullptr) {
        FMK_LOGE("please restore or build first.");
        return ndTensorDescs;
    }

    OH_NNCompilation* compilation = reinterpret_cast<OH_NNCompilation*>(nnCompilation_.get());
    if (nnExecutor_ == nullptr) {
        nnExecutor_ = std::shared_ptr<OH_NNExecutor>(HIAI_NDK_NNExecutor_Construct(compilation),
            [](OH_NNExecutor* p) { HIAI_NDK_NNExecutor_Destroy(&p); });
        HIAI_EXPECT_NOT_NULL_R(nnExecutor_, ndTensorDescs);
    }

    size_t inputCount = 0;
    Status retCode = HIAI_NDK_NNExecutor_GetInputCount(nnExecutor_.get(), &inputCount);
    if (retCode != SUCCESS || inputCount == 0) {
        FMK_LOGE("OH_NNExecutor_GetInputCount failed.");
        return ndTensorDescs;
    }

    for (size_t i = 0; i < inputCount; ++i) {
        NN_TensorDesc* nnTensorDesc = HIAI_NDK_NNExecutor_CreateInputTensorDesc(nnExecutor_.get(), i);
        if (nnTensorDesc == nullptr) {
            FMK_LOGE("OH_NNExecutor_CreateInputTensorDesc failed.");
            break;
        }
        NDTensorDesc ndTensorDesc;
        auto status = NNTensorDescToNDTensorDesc(nnTensorDesc, ndTensorDesc);
        HIAI_NDK_NNTensorDesc_Destroy(&nnTensorDesc);
        if (status != SUCCESS) {
            FMK_LOGE("NNTensorDescToNDTensorDesc failed.");
            break;
        }
        ndTensorDescs.push_back(ndTensorDesc);
    }
    if (ndTensorDescs.size() != inputCount) {
        FMK_LOGE("ndTensorDescs size mismatch.");
        ndTensorDescs.clear();
        return ndTensorDescs;
    }
    AippInputConverter::ConvertInputTensorDesc(customModelData_, ndTensorDescs);
    return ndTensorDescs;
}

std::vector<NDTensorDesc> BuiltModelNDKImpl::GetOutputTensorDescs() const
{
    std::vector<NDTensorDesc> ndTensorDescs;
    if (nnCompilation_ == nullptr) {
        FMK_LOGE("please restore or build first.");
        return ndTensorDescs;
    }

    OH_NNCompilation* compilation = reinterpret_cast<OH_NNCompilation*>(nnCompilation_.get());
    if (nnExecutor_ == nullptr) {
        nnExecutor_ = std::shared_ptr<OH_NNExecutor>(HIAI_NDK_NNExecutor_Construct(compilation),
            [](OH_NNExecutor* p) { HIAI_NDK_NNExecutor_Destroy(&p); });
        HIAI_EXPECT_NOT_NULL_R(nnExecutor_, ndTensorDescs);
    }

    size_t outputCount = 0;
    Status retCode = HIAI_NDK_NNExecutor_GetOutputCount(nnExecutor_.get(), &outputCount);
    if (retCode != SUCCESS || outputCount == 0) {
        FMK_LOGE("OH_NNExecutor_GetOutputCount failed.");
        return ndTensorDescs;
    }

    for (size_t i = 0; i < outputCount; ++i) {
        NN_TensorDesc* nnTensorDesc = HIAI_NDK_NNExecutor_CreateOutputTensorDesc(nnExecutor_.get(), i);
        if (nnTensorDesc == nullptr) {
            FMK_LOGE("OH_NNExecutor_CreateOutputTensorDesc failed.");
            break;
        }
        NDTensorDesc ndTensorDesc;
        auto status = NNTensorDescToNDTensorDesc(nnTensorDesc, ndTensorDesc);
        HIAI_NDK_NNTensorDesc_Destroy(&nnTensorDesc);
        if (status != SUCCESS) {
            FMK_LOGE("NNTensorDescToNDTensorDesc failed.");
            break;
        }
        ndTensorDescs.push_back(ndTensorDesc);
    }
    if (ndTensorDescs.size() != outputCount) {
        FMK_LOGE("ndTensorDescs size mismatch.");
        ndTensorDescs.clear();
        return ndTensorDescs;
    }
    return ndTensorDescs;
}

// The model name is only meaningful after a model has been restored or built.
std::string BuiltModelNDKImpl::GetName() const
{
    if (nnCompilation_ != nullptr) {
        return modelName_;
    }
    FMK_LOGE("please restore or build first.");
    return "";
}

// Renaming is rejected until a model has been restored or built.
void BuiltModelNDKImpl::SetName(const std::string& name)
{
    if (nnCompilation_ != nullptr) {
        modelName_ = name;
        return;
    }
    FMK_LOGE("please restore or build first.");
}

// Stores a copy of the custom data; it is used by save/restore and by the
// AIPP input-descriptor conversion.
void BuiltModelNDKImpl::SetCustomData(const CustomModelData& customModelData)
{
    customModelData_ = customModelData;
}

// Read-only access to the stored custom data.
const CustomModelData& BuiltModelNDKImpl::GetCustomData()
{
    return customModelData_;
}

Status BuiltModelNDKImpl::GetTensorAippInfo(int32_t index, uint32_t* aippParaNum, uint32_t* batchCount)
{
    (void)index;
    (void)aippParaNum;
    (void)batchCount;
    // AIPP info queries are not available through the NDK backend.
    FMK_LOGE("Not supported.");
    return FAILURE;
}

Status BuiltModelNDKImpl::GetTensorAippPara(int32_t index, std::vector<std::shared_ptr<IAIPPPara>>& aippParas) const
{
    (void)index;
    (void)aippParas;
    // AIPP parameter queries are not available through the NDK backend.
    FMK_LOGE("Not supported.");
    return FAILURE;
}

// Returns the cached executor (may be null before the first tensor-desc query).
std::shared_ptr<OH_NNExecutor> BuiltModelNDKImpl::GetExecutor() const
{
    return nnExecutor_;
}

// Replaces the cached executor (shares ownership with the caller).
void BuiltModelNDKImpl::ReSetExecutor(const std::shared_ptr<OH_NNExecutor> &executor)
{
    nnExecutor_ = executor;
}

// Returns the feature-map memory size reported for the built model, or 0 on
// any failure (no compilation yet, or the query itself fails).
uint64_t BuiltModelNDKImpl::GetFmMemorySize() const
{
    uint64_t fmSize = 0;
    // (fix) dropped the no-op reinterpret_cast: nnCompilation_.get() already
    // yields an OH_NNCompilation*.
    OH_NNCompilation* compilation = nnCompilation_.get();
    HIAI_EXPECT_NOT_NULL_R(compilation, 0);
    Status retCode = HIAI_NDK_HiAIBuiltmodel_GetFmMemorySize(compilation, &fmSize);
    HIAI_EXPECT_TRUE_R(retCode == SUCCESS, 0);
    return fmSize;
}

// Factory entry point: creates an empty NDK-backed built model, to be
// populated via RestoreFromBuffer/RestoreFromFile. Returns null on alloc failure.
HIAI_M_API_EXPORT std::shared_ptr<IBuiltModel> CreateBuiltModelFromNDK()
{
    std::shared_ptr<IBuiltModel> builtModel = make_shared_nothrow<BuiltModelNDKImpl>();
    return builtModel;
}
}